#ifndef HV_TCP_SERVER_HPP_
#define HV_TCP_SERVER_HPP_

#include "hv/hsocket.h"
#include "hv/hssl.h"
#include "hv/hlog.h"

#include "hv/EventLoopThreadPool.h"
#include "hv/Channel.h"

namespace hv {

    template<class TSocketChannel = SocketChannel>
    class TcpServerEventLoopTmpl {
    public:
        typedef std::shared_ptr<TSocketChannel> TSocketChannelPtr;

        TcpServerEventLoopTmpl(EventLoopPtr loop = nullptr) {
            acceptor_loop = loop ? loop : std::make_shared<EventLoop>();
            port = 0;
            listenfd = -1;
            tls = false;
            tls_setting = NULL;
            unpack_setting = NULL;
            max_connections = 0xFFFFFFFF;
            load_balance = LB_RoundRobin;
        }

        virtual ~TcpServerEventLoopTmpl() {
            HV_FREE(tls_setting);
            HV_FREE(unpack_setting);
        }

        EventLoopPtr loop(int idx = -1) {
            EventLoopPtr worker_loop = worker_threads.loop(idx);
            if (worker_loop == NULL) {
                worker_loop = acceptor_loop;
            }
            return worker_loop;
        }

        //@retval >=0 listenfd, <0 error
        int createsocket(int port, const char* host = "0.0.0.0") {
            listenfd = Listen(port, host);
            if (listenfd < 0) return listenfd;
            this->host = host;
            this->port = port;
            return listenfd;
        }
        // closesocket thread-safe
        void closesocket() {
            if (listenfd >= 0) {
                hloop_t* loop = acceptor_loop->loop();
                if (loop) {
                    hio_t* listenio = hio_get(loop, listenfd);
                    assert(listenio != NULL);
                    hio_close_async(listenio);
                }
                listenfd = -1;
            }
        }

        void setMaxConnectionNum(uint32_t num) {
            max_connections = num;
        }

        void setLoadBalance(load_balance_e lb) {
            load_balance = lb;
        }

        // NOTE: totalThreadNum = 1 acceptor_thread + N worker_threads (N can be 0)
        void setThreadNum(int num) {
            worker_threads.setThreadNum(num);
        }

        int startAccept() {
            if (listenfd < 0) {
                listenfd = createsocket(port, host.c_str());
                if (listenfd < 0) {
                    hloge("createsocket %s:%d return %d!\n", host.c_str(), port, listenfd);
                    return listenfd;
                }
            }
            hloop_t* loop = acceptor_loop->loop();
            if (loop == NULL) return -2;
            hio_t* listenio = haccept(loop, listenfd, onAccept);
            assert(listenio != NULL);
            hevent_set_userdata(listenio, this);
            if (tls) {
                hio_enable_ssl(listenio);
                if (tls_setting) {
                    int ret = hio_new_ssl_ctx(listenio, tls_setting);
                    if (ret != 0) {
                        hloge("new SSL_CTX failed: %d", ret);
                        closesocket();
                        return ret;
                    }
                }
            }
            return 0;
        }

        int stopAccept() {
            if (listenfd < 0) return -1;
            hloop_t* loop = acceptor_loop->loop();
            if (loop == NULL) return -2;
            hio_t* listenio = hio_get(loop, listenfd);
            assert(listenio != NULL);
            return hio_del(listenio, HV_READ);
        }

        // start thread-safe
        void start(bool wait_threads_started = true) {
            if (worker_threads.threadNum() > 0) {
                worker_threads.start(wait_threads_started);
            }
            acceptor_loop->runInLoop(std::bind(&TcpServerEventLoopTmpl::startAccept, this));
        }
        // stop thread-safe
        void stop(bool wait_threads_stopped = true) {
            closesocket();
            if (worker_threads.threadNum() > 0) {
                worker_threads.stop(wait_threads_stopped);
            }
        }

        int withTLS(hssl_ctx_opt_t* opt = NULL) {
            tls = true;
            if (opt) {
                if (tls_setting == NULL) {
                    HV_ALLOC_SIZEOF(tls_setting);
                }
                opt->endpoint = HSSL_SERVER;
                *tls_setting = *opt;
            }
            return 0;
        }

        void setUnpack(unpack_setting_t* setting) {
            if (setting == NULL) {
                HV_FREE(unpack_setting);
                return;
            }
            if (unpack_setting == NULL) {
                HV_ALLOC_SIZEOF(unpack_setting);
            }
            *unpack_setting = *setting;
        }

        // channel
        const TSocketChannelPtr& addChannel(hio_t* io) {
            uint32_t id = hio_id(io);
            auto channel = std::make_shared<TSocketChannel>(io);
            std::lock_guard<std::mutex> locker(mutex_);
            channels[id] = channel;
            return channels[id];
        }

        TSocketChannelPtr getChannelById(uint32_t id) {
            std::lock_guard<std::mutex> locker(mutex_);
            auto iter = channels.find(id);
            return iter != channels.end() ? iter->second : NULL;
        }

        void removeChannel(const TSocketChannelPtr& channel) {
            uint32_t id = channel->id();
            std::lock_guard<std::mutex> locker(mutex_);
            channels.erase(id);
        }

        size_t connectionNum() {
            std::lock_guard<std::mutex> locker(mutex_);
            return channels.size();
        }

        int foreachChannel(std::function<void(const TSocketChannelPtr& channel)> fn) {
            std::lock_guard<std::mutex> locker(mutex_);
            for (auto& pair : channels) {
                fn(pair.second);
            }
            return channels.size();
        }

        // broadcast thread-safe
        int broadcast(const void* data, int size) {
            return foreachChannel([data, size](const TSocketChannelPtr& channel) {
                channel->write(data, size);
            });
        }

        int broadcast(const std::string& str) {
            return broadcast(str.data(), str.size());
        }

    private:
        static void newConnEvent(hio_t* connio) {
            TcpServerEventLoopTmpl* server = (TcpServerEventLoopTmpl*)hevent_userdata(connio);
            if (server->connectionNum() >= server->max_connections) {
                hlogw("over max_connections");
                hio_close(connio);
                return;
            }

            // NOTE: attach to worker loop
            EventLoop* worker_loop = currentThreadEventLoop;
            assert(worker_loop != NULL);
            hio_attach(worker_loop->loop(), connio);

            const TSocketChannelPtr& channel = server->addChannel(connio);
            channel->status = SocketChannel::CONNECTED;

            channel->onread = [server, &channel](Buffer* buf) {
                if (server->onMessage) {
                    server->onMessage(channel, buf);
                }
            };
            channel->onwrite = [server, &channel](Buffer* buf) {
                if (server->onWriteComplete) {
                    server->onWriteComplete(channel, buf);
                }
            };
            channel->onclose = [server, &channel]() {
                EventLoop* worker_loop = currentThreadEventLoop;
                assert(worker_loop != NULL);
                --worker_loop->connectionNum;

                channel->status = SocketChannel::CLOSED;
                if (server->onConnection) {
                    server->onConnection(channel);
                }
                server->removeChannel(channel);
                // NOTE: After removeChannel, channel may be destroyed,
                // so in this lambda function, no code should be added below.
            };

            if (server->unpack_setting) {
                channel->setUnpack(server->unpack_setting);
            }
            channel->startRead();
            if (server->onConnection) {
                server->onConnection(channel);
            }
        }

        static void onAccept(hio_t* connio) {
            TcpServerEventLoopTmpl* server = (TcpServerEventLoopTmpl*)hevent_userdata(connio);
            // NOTE: detach from acceptor loop
            hio_detach(connio);
            EventLoopPtr worker_loop = server->worker_threads.nextLoop(server->load_balance);
            if (worker_loop == NULL) {
                worker_loop = server->acceptor_loop;
            }
            ++worker_loop->connectionNum;
            worker_loop->runInLoop(std::bind(&TcpServerEventLoopTmpl::newConnEvent, connio));
        }

    public:
        std::string             host;
        int                     port;
        int                     listenfd;
        bool                    tls;
        hssl_ctx_opt_t* tls_setting;
        unpack_setting_t* unpack_setting;
        // Callback
        std::function<void(const TSocketChannelPtr&)>           onConnection;
        std::function<void(const TSocketChannelPtr&, Buffer*)>  onMessage;
        // NOTE: Use Channel::isWriteComplete in onWriteComplete callback to determine whether all data has been written.
        std::function<void(const TSocketChannelPtr&, Buffer*)>  onWriteComplete;

        uint32_t                max_connections;
        load_balance_e          load_balance;

    private:
        // id => TSocketChannelPtr
        std::map<uint32_t, TSocketChannelPtr>   channels; // GUAREDE_BY(mutex_)
        std::mutex                              mutex_;

        EventLoopPtr            acceptor_loop;
        EventLoopThreadPool     worker_threads;
    };

    template<class TSocketChannel = SocketChannel>
    class TcpServerTmpl : private EventLoopThread, public TcpServerEventLoopTmpl<TSocketChannel> {
    public:
        TcpServerTmpl(EventLoopPtr loop = NULL)
            : EventLoopThread(loop)
            , TcpServerEventLoopTmpl<TSocketChannel>(EventLoopThread::loop())
            , is_loop_owner(loop == NULL)
        {}
        virtual ~TcpServerTmpl() {
            stop(true);
        }

        EventLoopPtr loop(int idx = -1) {
            return TcpServerEventLoopTmpl<TSocketChannel>::loop(idx);
        }

        // start thread-safe
        void start(bool wait_threads_started = true) {
            TcpServerEventLoopTmpl<TSocketChannel>::start(wait_threads_started);
            EventLoopThread::start(wait_threads_started);
        }

        // stop thread-safe
        void stop(bool wait_threads_stopped = true) {
            if (is_loop_owner) {
                EventLoopThread::stop(wait_threads_stopped);
            }
            TcpServerEventLoopTmpl<TSocketChannel>::stop(wait_threads_stopped);
        }

    private:
        bool is_loop_owner;
    };

    typedef TcpServerTmpl<SocketChannel> TcpServer;

}

#endif // HV_TCP_SERVER_HPP_
